{
 "cells": [
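  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ternary Classifier Plots\n",
    "\n",
    "Loads the saved ternary-classifier predictions and plots, for each arXiv category, the share of diagram, sensor and unsure predictions as stacked bars."
   ]
  },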
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import numpy as np\n",
    "import sqlite3\n",
    "import pickle\n",
    "import json\n",
    "import math\n",
    "import random\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions pickle produced by `ternary_classifier_predictions.py`.\n",
    "image_pkl_filename = \"ternary_classifier_predictions.pkl\"\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.\n",
    "with open(image_pkl_filename, mode=\"rb\") as pkl_handle:\n",
    "    image_data = pickle.load(pkl_handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"paper\" plots only the rows picked via `index`; \"test\" plots every row.\n",
    "mode = \"paper\"\n",
    "# mode = \"test\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shallow list copy of the raw data; `test_data` is reassigned later once\n",
    "# `mode` is applied.\n",
    "test_data = image_data[:]\n",
    "\n",
    "# Each row appears to mix a category string with nested count lists, i.e. the\n",
    "# input is ragged. Recent NumPy releases raise ValueError for ragged input\n",
    "# unless dtype=object is requested explicitly.\n",
    "arr = np.array(image_data, dtype=object)\n",
    "arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Categories chosen for the paper figure, stored in sorted order.\n",
    "selected_cats = sorted(\n",
    "    [\n",
    "        \"cs.CV\", \"stat.ML\", \"nlin.CG\", \"cs.GR\",\n",
    "        \"cs.AI\", \"astro-ph\", \"astro-ph.GA\", \"astro-ph.IM\",\n",
    "        \"cond-mat.str-el\", \"cs.LG\", \"cs.IT\", \"math-ph\",\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Short alias for the same list object.\n",
    "scats = selected_cats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First column of every row: the category name that the selected\n",
    "# categories are matched against.\n",
    "cats = [row[0] for row in arr[:, :1]]\n",
    "cats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cats = np.array(cats)\n",
    "# cats.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the positions in `cats` whose category is one of the selected ones.\n",
    "# `scats == cats` compares two Python lists and yields a single bool, so\n",
    "# np.where on it always comes back empty; np.isin performs the intended\n",
    "# element-wise membership test. Result positions follow the order of `cats`.\n",
    "index = np.where(np.isin(cats, scats))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row position of each selected category, collected in `scats` order.\n",
    "index = []\n",
    "\n",
    "for wanted in scats:\n",
    "    position = cats.index(wanted)  # raises ValueError if the category is absent\n",
    "    index.append(position)\n",
    "    print(position)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# random indexes\n",
    "# index = np.random.choice(arr.shape[0], 15, replace=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prep data for plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"test\" keeps every row; \"paper\" keeps only the rows listed in `index`.\n",
    "# Any other value leaves `test_data` as it was.\n",
    "if mode == \"test\":\n",
    "    test_data = arr\n",
    "elif mode == \"paper\":\n",
    "    test_data = arr[index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "test_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate plot\n",
    "\n",
    "Loop over all of the data and retrieve the category, year and predictions. Calculate the totals for each class prediction and convert to percentages for diagram, sensor and unsure. Plot each category to a subplot with bars for each year within that category. Delete additional plots, add title and legend, then save.\n",
    "\n",
    "See below for updated version for paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subplot grid: 15 columns x 12 rows, one cell per category.\n",
    "xdim = 15\n",
    "ydim = 12\n",
    "\n",
    "fig, ax = plt.subplots(ydim, xdim)\n",
    "fig.subplots_adjust(hspace=0.2, wspace=0.2)\n",
    "fig.set_size_inches(40, 30)\n",
    "\n",
    "# NOTE(review): pyplot tick helpers act on the current axes only, not on\n",
    "# every subplot; per-axes tick labels are blanked inside the loop.\n",
    "plt.yticks([])\n",
    "plt.xticks([])\n",
    "\n",
    "# Running totals of diagram / sensor / unsure counts over everything plotted.\n",
    "dt = 0\n",
    "st = 0\n",
    "ut = 0\n",
    "\n",
    "for i, cat in enumerate(test_data[:]):\n",
    "\n",
    "    # Row-major grid position of subplot number i.\n",
    "    y, x = divmod(i, xdim)\n",
    "    print(x, y)\n",
    "\n",
    "    print(\"category - \",cat[0])\n",
    "\n",
    "    diagrams = []\n",
    "    sensors = []\n",
    "    unsures = []\n",
    "    sum_totals = []\n",
    "\n",
    "    # cat[3] is iterated as a sequence of count lists ordered\n",
    "    # [diagram, sensor, unsure]; each one becomes one stacked bar.\n",
    "    for class_totals in cat[3]:\n",
    "        print(class_totals)\n",
    "        sum_totals.append(sum(class_totals))\n",
    "\n",
    "        diagrams.append(class_totals[0])\n",
    "        sensors.append(class_totals[1])\n",
    "        unsures.append(class_totals[2])\n",
    "\n",
    "        dt += class_totals[0]\n",
    "        st += class_totals[1]\n",
    "        ut += class_totals[2]\n",
    "\n",
    "    print(\"t\",sum_totals)\n",
    "    print(\"d\",diagrams)\n",
    "    print(\"s\",sensors)\n",
    "    print(\"u\",unsures)\n",
    "\n",
    "    # Convert counts to percentages of each bar's total. A bar whose total is\n",
    "    # zero is drawn as 0% rather than raising ZeroDivisionError.\n",
    "    diagram_percentages = [count / total * 100 if total else 0.0 for count, total in zip(diagrams, sum_totals)]\n",
    "    sensor_percentages = [count / total * 100 if total else 0.0 for count, total in zip(sensors, sum_totals)]\n",
    "    unsure_percentages = [count / total * 100 if total else 0.0 for count, total in zip(unsures, sum_totals)]\n",
    "\n",
    "    print(\"d%\",diagram_percentages)\n",
    "    print(\"s%\",sensor_percentages)\n",
    "    print(\"u%\",unsure_percentages)\n",
    "\n",
    "    indexes = list(range(len(diagrams)))\n",
    "    print(\"indexes:\",indexes)\n",
    "\n",
    "    # Stack the three classes: diagram at the bottom, then sensor, then unsure.\n",
    "    unsure_bottoms = [low + mid for low, mid in zip(diagram_percentages, sensor_percentages)]\n",
    "    ax[y,x].bar(indexes, diagram_percentages, color='#26bfb8', edgecolor='white')\n",
    "    ax[y,x].bar(indexes, sensor_percentages, bottom=diagram_percentages, color='#e6d929', edgecolor='white')\n",
    "    ax[y,x].bar(indexes, unsure_percentages, bottom=unsure_bottoms, color='#e62929', edgecolor='white')\n",
    "\n",
    "    ax[y,x].set_xticklabels([])\n",
    "    ax[y,x].set_yticklabels([])\n",
    "\n",
    "    ax[y,x].title.set_text(cat[0])\n",
    "\n",
    "    print(\"*\" * 20)\n",
    "\n",
    "fig.suptitle(\"Ternary classifier predictions on arXiv primary categories\", x=0.5, y=0.92, size=28)\n",
    "\n",
    "# Remove the grid cells that did not receive a category.\n",
    "for i in range(len(test_data), xdim * ydim):\n",
    "    print(\"i:\",i)\n",
    "    y, x = divmod(i, xdim)\n",
    "    print(x, y)\n",
    "\n",
    "    fig.delaxes(ax[y][x])\n",
    "\n",
    "# Figure-level legend built from proxy rectangles, one per class colour.\n",
    "colors = {'unsure':'#e62929', 'sensor':'#e6d929', 'diagram':'#26bfb8'}\n",
    "labels = list(colors.keys())\n",
    "handles = [plt.Rectangle((0,0),1,1, color=colors[label]) for label in labels]\n",
    "fig.legend(handles, labels, loc=(0.908,0.08))\n",
    "\n",
    "fig.savefig(\"plot_ternary_classifier_predictions.svg\", dpi=300, bbox_inches='tight', pad_inches=0.5)\n",
    "\n",
    "print(\"dt:\",dt)\n",
    "print(\"st:\",st)\n",
    "print(\"ut:\",ut)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Methods paper version\n",
    "\n",
    "Updated to only show key categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subplot grid: 4 columns x 3 rows, i.e. room for 12 categories.\n",
    "xdim = 4\n",
    "ydim = 3\n",
    "\n",
    "fig, ax = plt.subplots(ydim, xdim)\n",
    "fig.subplots_adjust(hspace=0.2, wspace=0.2)\n",
    "fig.set_size_inches(16, 12)\n",
    "\n",
    "# NOTE(review): pyplot tick helpers act on the current axes only, not on\n",
    "# every subplot; per-axes tick labels are blanked inside the loop.\n",
    "plt.yticks([])\n",
    "plt.xticks([])\n",
    "\n",
    "# Running totals of diagram / sensor / unsure counts over everything plotted;\n",
    "# the dataset-totals cell further down reads them.\n",
    "dt = 0\n",
    "st = 0\n",
    "ut = 0\n",
    "\n",
    "for i, cat in enumerate(test_data[:]):\n",
    "\n",
    "    # Row-major grid position of subplot number i.\n",
    "    y, x = divmod(i, xdim)\n",
    "    print(x, y)\n",
    "\n",
    "    print(\"category - \",cat[0])\n",
    "\n",
    "    diagrams = []\n",
    "    sensors = []\n",
    "    unsures = []\n",
    "    sum_totals = []\n",
    "\n",
    "    # cat[3] is iterated as a sequence of count lists ordered\n",
    "    # [diagram, sensor, unsure]; each one becomes one stacked bar.\n",
    "    for class_totals in cat[3]:\n",
    "        print(class_totals)\n",
    "        sum_totals.append(sum(class_totals))\n",
    "\n",
    "        diagrams.append(class_totals[0])\n",
    "        sensors.append(class_totals[1])\n",
    "        unsures.append(class_totals[2])\n",
    "\n",
    "        dt += class_totals[0]\n",
    "        st += class_totals[1]\n",
    "        ut += class_totals[2]\n",
    "\n",
    "    print(\"t\",sum_totals)\n",
    "    print(\"d\",diagrams)\n",
    "    print(\"s\",sensors)\n",
    "    print(\"u\",unsures)\n",
    "\n",
    "    # Convert counts to percentages of each bar's total. A bar whose total is\n",
    "    # zero is drawn as 0% rather than raising ZeroDivisionError.\n",
    "    diagram_percentages = [count / total * 100 if total else 0.0 for count, total in zip(diagrams, sum_totals)]\n",
    "    sensor_percentages = [count / total * 100 if total else 0.0 for count, total in zip(sensors, sum_totals)]\n",
    "    unsure_percentages = [count / total * 100 if total else 0.0 for count, total in zip(unsures, sum_totals)]\n",
    "\n",
    "    print(\"d%\",diagram_percentages)\n",
    "    print(\"s%\",sensor_percentages)\n",
    "    print(\"u%\",unsure_percentages)\n",
    "\n",
    "    indexes = list(range(len(diagrams)))\n",
    "    print(\"indexes:\",indexes)\n",
    "\n",
    "    # Stack the three classes: diagram at the bottom, then sensor, then the\n",
    "    # third class (labelled 'mixed' in this figure's legend).\n",
    "    unsure_bottoms = [low + mid for low, mid in zip(diagram_percentages, sensor_percentages)]\n",
    "    ax[y,x].bar(indexes, diagram_percentages, color='#26bfb8', edgecolor='white')\n",
    "    ax[y,x].bar(indexes, sensor_percentages, bottom=diagram_percentages, color='#e6d929', edgecolor='white')\n",
    "    ax[y,x].bar(indexes, unsure_percentages, bottom=unsure_bottoms, color='#e62929', edgecolor='white')\n",
    "\n",
    "    ax[y,x].set_xticklabels([])\n",
    "    ax[y,x].set_yticklabels([])\n",
    "\n",
    "    ax[y,x].title.set_text(cat[0])\n",
    "\n",
    "    print(\"*\" * 20)\n",
    "\n",
    "# Remove the grid cells that did not receive a category.\n",
    "for i in range(len(test_data), xdim * ydim):\n",
    "    print(\"i:\",i)\n",
    "    y, x = divmod(i, xdim)\n",
    "    print(x, y)\n",
    "\n",
    "    fig.delaxes(ax[y][x])\n",
    "\n",
    "# Figure-level legend built from proxy rectangles, one per class colour.\n",
    "colors = {'mixed':'#e62929', 'sensor':'#e6d929', 'diagram':'#26bfb8'}\n",
    "labels = list(colors.keys())\n",
    "handles = [plt.Rectangle((0,0),1,1, color=colors[label]) for label in labels]\n",
    "fig.legend(handles, labels, loc=(0.92,0.075))\n",
    "\n",
    "# The figure is written to disk by the following cell.\n",
    "\n",
    "print(\"dt:\",dt)\n",
    "print(\"st:\",st)\n",
    "print(\"ut:\",ut)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the current figure as both SVG and PNG with identical settings.\n",
    "for out_path in (\n",
    "    \"plot_ternary_classifier_predictions_subset_mixed.svg\",\n",
    "    \"plot_ternary_classifier_predictions_subset_mixed.png\",\n",
    "):\n",
    "    fig.savefig(out_path, dpi=300, bbox_inches='tight', pad_inches=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall total and the fraction (0-1) of each class, using the running\n",
    "# counts left behind by the most recently executed plotting cell.\n",
    "tt = dt + st + ut\n",
    "dper, sper, uper = dt / tt, st / tt, ut / tt\n",
    "\n",
    "for value in (tt, dper, sper, uper):\n",
    "    print(value)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
